feat: Add report_progress to TrainContext #9826
Conversation
✅ Deploy Preview for determined-ui canceled.
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #9826      +/-   ##
==========================================
- Coverage   54.38%   54.37%   -0.01%
==========================================
  Files        1261     1261
  Lines      155770   155787      +17
  Branches     3540     3540
==========================================
- Hits        84711    84709       -2
- Misses      70921    70940      +19
  Partials      138      138

Flags with carried forward coverage won't be shown.
harness/determined/core/_train.py
Outdated
The ``progress`` should be the actual progress.
"""
logger.debug("report_progress()")
let's put the progress value in the log statement
also, we should probably validate > 0 < 1 here too
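Taken together, these two review comments suggest something like the following on the harness side. This is a minimal standalone sketch, not the actual TrainContext method (which also POSTs the value to the master); raising ValueError on an out-of-range value is an assumption about how the validation might be done:

```python
import logging

logger = logging.getLogger(__name__)


def report_progress(progress: float) -> None:
    # Hypothetical standalone sketch; the real method lives on TrainContext.
    # Validate the range up front, per the review comment.
    if not 0.0 <= progress <= 1.0:
        raise ValueError(f"invalid progress value: {progress}; must be in [0, 1.0]")
    # Include the progress value in the log statement, per the review comment.
    logger.debug("report_progress(%s)", progress)
```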
@@ -260,6 +260,19 @@ def report_early_exit(self, reason: EarlyExitReason) -> None:
        if r.status_code == 400:
            logger.warn("early exit has already been reported for this trial, ignoring new value")

    def report_progress(self, progress: float) -> None:
this method also needs to be in DummyTrainContext, otherwise local training will not work
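For reference, a minimal sketch of what that could look like; the real DummyTrainContext implements the full TrainContext interface, so the logging-only body here is an assumption:

```python
import logging

logger = logging.getLogger(__name__)


class DummyTrainContext:
    # Sketch of a no-op context for local training: there is no master to
    # report to, so the method just logs the value.
    def report_progress(self, progress: float) -> None:
        logger.info(f"report_progress with progress={progress}")
```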
can show accurate progress to users.

The ``progress`` should be the actual progress.
"""
nit: wording, and style. i know we're not consistent with style for docstrings in this class, but let's do it for new methods. we mostly try to follow the google style guide
suggestion:
"""
Report training progress to the master.
This is optional for training, but will be used by the WebUI to render completion status.
Progress must be reported as a float between 0 and 1.0, where 1.0 is 100% completion. It
should represent the current iteration step as a fraction of maximum training steps
(i.e.: `report_progress(step_num / max_steps)`).
Note that for hyperparameter search, progress should be reported through
``SearcherOperation.report_progress()`` in the Searcher API instead.
Arguments:
progress (float): completion progress in the range [0, 1.0].
"""
@@ -1406,6 +1406,7 @@ func (a *apiServer) ReportTrialProgress(
	msg := experiment.TrialReportProgress{
		RequestID: rID,
		Progress:  searcher.PartialUnits(req.Progress),
if it's not too much effort, it'd be great if we could get rid of this searcher.PartialUnits type, it's just a float anyway. will need to do it sooner or later, but not strictly necessary as part of this PR.
Since it touches multiple files and isn't directly related to this PR, I created a ticket for it
master/internal/experiment.go
Outdated
}
if progress < 0 || progress > 1 {
	e.syslog.Errorf("Invalid progress value: %f", progress)
	return nil
this should return an error, otherwise we'll return a 200 in the HTTP POST
Sure, but I'm curious why the DB error next line is ignored?
it should also be returned, feel free to update it if you want. the only reason i can think of to not return it is if you have a one-off DB connection failure or something and you don't want to exit training because progress reporting isn't important.
this case is important tho IMO because if progress isn't [0,1] then that means user code is wrong and they should know to fix it.
@@ -482,6 +485,22 @@ def test_core_api_distributed_tutorial() -> None:
    )


@pytest.mark.e2e_cpu
this feature is logically simple enough that i don't think it's worth the time/resource/maintenance cost to have an e2e test, tbh.
if you wanted you could probably add a unit test for it, but i also think that's overkill, since it's just hitting an existing API.
@@ -24,6 +24,8 @@ def main(core_context, increment_by):
        core_context.train.report_training_metrics(
            steps_completed=steps_completed, metrics={"x": x}
        )
        # NEW: report training progress.
        core_context.train.report_progress(batch/100.0)
- the denominator is confusing here, because it looks like we're calculating out of 100% or something. would be better if we defined max_length=100 as a variable outside this loop, and used it in the for batch in range(max_length) and also here.
- shouldn't be batch, it should be steps_completed
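The reviewer's two points could be folded into the tutorial loop like this. This is a sketch assuming the surrounding loop from 1_metrics.py; the `run` wrapper and initial values are illustrative, only report_training_metrics and report_progress come from the PR:

```python
def run(core_context, x, increment_by):
    max_length = 100  # defined once, outside the loop, per the review comment
    for batch in range(max_length):
        x += increment_by
        steps_completed = batch + 1
        core_context.train.report_training_metrics(
            steps_completed=steps_completed, metrics={"x": x}
        )
        # NEW: report progress as steps_completed over max_length.
        core_context.train.report_progress(steps_completed / max_length)
```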
- include this change in 2_checkpoints.py, too. they're meant to be incremental tutorials, hence the "# NEW: ..."
- also update detached mode tutorials (and make sure this works in detached mode. it should, but just in case)
I don't think this function would work with detached mode out of the box, because in the existing method we retrieve the experiment from experiment.ExperimentRegistry, but currently we do not include unmanaged experiments in ExperimentRegistry. So we can either choose to include unmanaged experiments in ExperimentRegistry, or retrieve unmanaged experiments from the DB instead.
harness/determined/core/_train.py
Outdated
@@ -312,6 +336,9 @@ def upload_tensorboard_files(
    def report_early_exit(self, reason: EarlyExitReason) -> None:
        logger.info(f"report_early_exit({reason})")

    def report_progress(self, progress: float) -> None:
        logger.info(f"report_progres with progress={progress}")
nit: typo
a few other small things for technical correctness:
- the check
  if progress < 0 || progress > 1 {
      return errors.Errorf("Invalid progress value: %f", progress)
  }
  should be moved up to (a *apiServer) ReportTrialProgress, since now we're saving to db for unmanaged exps upfront.
- (a *apiServer) PatchTrial should be updated to set progress to 1.0 when unmanaged trials exit. currently if your last progress report was .90, it'd never reach 100%. not sure what this means for the web ui, but not a big deal.
nice work! thanks for helping out with this ticket! 🙏
Backend LGTM
Ticket
MD-499
Description
Add report_progress to TrainContext.

Test Plan
Find an experiment using report_progress from TrainContext, such as examples/tutorials/core_api/1_metrics.py. Start the experiment, and while it is running, monitor it to verify that the progress changes.

Checklist
- docs/release-notes/: See Release Note for details.